import numpy as np
import dgl
import torch
import os
import logging
from sklearn.metrics import average_precision_score, f1_score, roc_auc_score, accuracy_score
import torch.optim as optim
from scipy.io import loadmat
import pandas as pd
import pickle
from sklearn.model_selection import StratifiedKFold, train_test_split
import torch.nn as nn
from sklearn.preprocessing import LabelEncoder, QuantileTransformer
from dgl.dataloading import MultiLayerFullNeighborSampler
from dgl.dataloading import DataLoader
from torch.optim.lr_scheduler import MultiStepLR
from .gtan_model import GraphAttnModel
from . import *

# Geometric mean (G-mean) of two per-class recall values
def geometric_mean(recall_0, recall_1):
    """Return sqrt(recall_0 * recall_1), the geometric mean of two recalls.

    Accepts scalars or NumPy arrays (``np.sqrt`` is applied element-wise).
    """
    product = recall_0 * recall_1
    return np.sqrt(product)

# Compute the G-mean (geometric mean of per-class recalls) of binary predictions
def calculate_g_mean(y_true, y_pred):
    """Return the G-mean of binary predictions.

    The recall of a class is the fraction of its samples in ``y_true`` that
    ``y_pred`` labels correctly; the G-mean is the geometric mean of the
    class-0 and class-1 recalls. A class absent from ``y_true`` contributes a
    recall of 0 (so the G-mean is 0). Labels other than 0/1 are ignored.

    :param y_true: NumPy array of true labels.
    :param y_pred: NumPy array of predicted labels, same shape as ``y_true``.
    """
    def _recall_of(cls):
        members = y_true == cls
        if not np.any(members):
            return 0
        return np.mean(y_pred[members] == y_true[members])

    recall_neg = _recall_of(0)
    recall_pos = _recall_of(1)
    return geometric_mean(recall_neg, recall_pos)

def _setup_logging(log_file):
    """Send root-logger INFO records to ``log_file`` (timestamped) and to the console.

    Handlers installed by an earlier call are removed and closed first, so
    repeated calls in one process neither duplicate every console line nor
    keep writing into a previous run's log file. (``logging.basicConfig``
    would silently ignore ``log_file`` once the root logger has any handler.)
    """
    root = logging.getLogger('')
    for handler in list(root.handlers):
        if getattr(handler, '_gtan_handler', False):
            root.removeHandler(handler)
            handler.close()
    root.setLevel(logging.INFO)

    file_handler = logging.FileHandler(log_file)
    file_handler.setFormatter(logging.Formatter('%(asctime)s - %(message)s',
                                                datefmt='%Y-%m-%d %H:%M:%S'))
    console = logging.StreamHandler()
    console.setLevel(logging.INFO)
    for handler in (file_handler, console):
        handler._gtan_handler = True  # marks handlers owned by this module
        root.addHandler(handler)


def _class_recalls(labels_np, preds_np):
    """Return ``(acc1, acc0)``: fraction of class-1 / class-0 samples predicted correctly.

    A class that does not occur in ``labels_np`` gets 0.0.
    """
    pos = labels_np == 1
    neg = labels_np == 0
    acc1 = np.mean(preds_np[pos] == labels_np[pos]) if np.any(pos) else 0.0
    acc0 = np.mean(preds_np[neg] == labels_np[neg]) if np.any(neg) else 0.0
    return acc1, acc0


def gtan_main(feat_df, graph, train_idx, test_idx, labels, args, cat_features):
    """Train GTAN with stratified K-fold CV and log validation / test metrics.

    For every fold of ``StratifiedKFold`` over ``train_idx`` a fresh
    ``GraphAttnModel`` is trained with early stopping on the validation loss,
    and that fold's best model is evaluated on ``test_idx``. Validation logits
    are written into ``oof_predictions``; test logits are written into
    ``test_predictions`` and overwritten by each fold, so the final test
    metrics reflect the last fold's best model.

    NOTE: each fold's training set is deliberately reduced to the FIRST
    positive and the FIRST negative node of that fold; validation uses the
    fold's full held-out part.

    :param feat_df: node feature DataFrame, indexed by node id (row i <-> node i).
    :param graph: DGL graph over the same node ids.
    :param train_idx: node ids used for K-fold train/validation.
    :param test_idx: node ids used for testing.
    :param labels: pandas Series of node labels; label 2 is excluded from the
        losses and final test metrics (presumably "unlabeled").
    :param args: config dict; reads ``device``, ``n_fold``, ``n_layers``,
        ``batch_size``, ``hid_dim``, ``dropout``, ``gated``, ``lr``, ``wd``,
        ``early_stopping``, ``max_epochs`` and optionally ``dataset``.
        ``args['seed']`` is overwritten with 64.
    :param cat_features: names of categorical columns in ``feat_df``.
    :returns: None; all results are logged.
    """
    # Fixed random seed (64); this overrides any seed supplied in args.
    args['seed'] = 64
    np.random.seed(args['seed'])
    torch.manual_seed(args['seed'])
    torch.cuda.manual_seed_all(args['seed'])

    # Log to <package>/../../logs/gtan_log_<dataset>_seed<seed>.txt and to the console.
    log_dir = os.path.join(os.path.dirname(__file__), "..", "..", "logs")
    os.makedirs(log_dir, exist_ok=True)
    log_file = os.path.join(log_dir, f"gtan_log_{args.get('dataset', 'unknown')}_seed{args['seed']}.txt")
    _setup_logging(log_file)

    device = args['device']
    logging.info(f'Device: {device}')
    graph = graph.to(device)
    # Per-node 2-class logits, filled by the validation and test passes.
    oof_predictions = torch.from_numpy(
        np.zeros([len(feat_df), 2])).float().to(device)
    test_predictions = torch.from_numpy(
        np.zeros([len(feat_df), 2])).float().to(device)
    kfold = StratifiedKFold(
        n_splits=args['n_fold'], shuffle=True, random_state=args['seed'])

    y_target = labels.iloc[train_idx].values
    num_feat = torch.from_numpy(feat_df.values).float().to(device)
    cat_feat = {col: torch.from_numpy(feat_df[col].values).long().to(
        device) for col in cat_features}

    y = labels
    labels = torch.from_numpy(y.values).long().to(device)
    loss_fn = nn.CrossEntropyLoss().to(device)
    for fold, (trn_idx, val_idx) in enumerate(kfold.split(feat_df.iloc[train_idx], y_target)):
        logging.info(f'Training fold {fold + 1}')

        # Node ids of this fold's training part.
        original_trn_ind = np.array(train_idx)[trn_idx]

        # Split the fold's training nodes by class.
        pos_samples = [i for i in original_trn_ind if y.iloc[i] == 1]
        neg_samples = [i for i in original_trn_ind if y.iloc[i] == 0]

        if len(pos_samples) == 0:
            logging.warning("训练集中没有正样本，无法选择一个正样本")
        if len(neg_samples) == 0:
            logging.warning("训练集中没有负样本，无法选择一个负样本")

        # Deliberately train on only one positive and one negative node: the
        # first of each class in fold order (empty if the class is absent).
        selected_pos = pos_samples[:1]
        selected_neg = neg_samples[:1]
        trn_ind_list = selected_pos + selected_neg

        logging.info(f'训练集正样本数: {len(selected_pos)}, 负样本数: {len(selected_neg)}')

        trn_ind = torch.tensor(trn_ind_list).long().to(device)
        val_ind = torch.from_numpy(np.array(train_idx)[val_idx]).long().to(device)

        logging.info(f'训练/验证/测试样本数: {len(trn_ind)}, {len(val_ind)}, {len(test_idx)}')

        train_sampler = MultiLayerFullNeighborSampler(args['n_layers'])
        train_dataloader = DataLoader(graph,
                                      trn_ind,
                                      train_sampler,
                                      device=device,
                                      use_ddp=False,
                                      batch_size=args['batch_size'],
                                      shuffle=True,
                                      drop_last=False,
                                      num_workers=0
                                      )
        val_sampler = MultiLayerFullNeighborSampler(args['n_layers'])
        val_dataloader = DataLoader(graph,
                                    val_ind,
                                    val_sampler,
                                    use_ddp=False,
                                    device=device,
                                    batch_size=args['batch_size'],
                                    shuffle=True,
                                    drop_last=False,
                                    num_workers=0,
                                    )
        # hid_dim // 4 per head: presumably so the 4 concatenated attention
        # heads add up to hid_dim -- TODO confirm in GraphAttnModel.
        model = GraphAttnModel(in_feats=feat_df.shape[1],
                               hidden_dim=args['hid_dim']//4,
                               n_classes=2,
                               heads=[4]*args['n_layers'],
                               activation=nn.PReLU(),
                               n_layers=args['n_layers'],
                               drop=args['dropout'],
                               device=device,
                               gated=args['gated'],
                               ref_df=feat_df,
                               cat_features=cat_feat).to(device)
        # Learning rate scaled by sqrt(batch_size / 1024).
        lr = args['lr'] * np.sqrt(args['batch_size']/1024)
        optimizer = optim.Adam(model.parameters(), lr=lr,
                               weight_decay=args['wd'])
        # Stepped once per training batch, so milestones count optimizer steps.
        lr_scheduler = MultiStepLR(optimizer=optimizer, milestones=[
                                   4000, 12000], gamma=0.3)

        earlystoper = early_stopper(
            patience=args['early_stopping'], verbose=True)
        start_epoch = 0
        for epoch in range(start_epoch, args['max_epochs']):
            train_loss_list = []
            model.train()
            for step, (input_nodes, seeds, blocks) in enumerate(train_dataloader):
                batch_inputs, batch_work_inputs, batch_labels, lpa_labels = load_lpa_subtensor(num_feat, cat_feat, labels,
                                                                                               seeds, input_nodes, device)
                blocks = [block.to(device) for block in blocks]
                train_batch_logits = model(
                    blocks, batch_inputs, lpa_labels, batch_work_inputs)
                # Drop label-2 nodes from the loss.
                mask = batch_labels == 2
                train_batch_logits = train_batch_logits[~mask]
                batch_labels = batch_labels[~mask]

                train_loss = loss_fn(train_batch_logits, batch_labels)
                optimizer.zero_grad()
                train_loss.backward()
                optimizer.step()
                lr_scheduler.step()
                train_loss_list.append(train_loss.cpu().detach().numpy())

                if step % 10 == 0:
                    detached_logits = train_batch_logits.clone().detach()
                    tr_batch_pred = torch.sum(torch.argmax(
                        detached_logits, dim=1) == batch_labels) / batch_labels.shape[0]
                    score = torch.softmax(detached_logits, dim=1)[
                        :, 1].cpu().numpy()
                    pred_labels = torch.argmax(detached_logits, dim=1).cpu().numpy()
                    batch_labels_np = batch_labels.cpu().numpy()

                    train_acc1, train_acc0 = _class_recalls(batch_labels_np, pred_labels)
                    train_gmean = calculate_g_mean(batch_labels_np, pred_labels)

                    try:
                        log_msg = ('In epoch:{:03d}|batch:{:04d}, train_loss:{:4f}, '
                                   'train_ap:{:.4f}, train_acc:{:.4f}, train_auc:{:.4f}, '
                                   'train_acc1:{:.4f}, train_acc0:{:.4f}, train_gmean:{:.4f}')

                        logging.info(log_msg.format(epoch, step,
                                                    np.mean(train_loss_list),
                                                    average_precision_score(batch_labels_np, score),
                                                    tr_batch_pred.detach(),
                                                    roc_auc_score(batch_labels_np, score),
                                                    train_acc1, train_acc0, train_gmean))
                    except Exception as e:
                        # Best-effort logging: e.g. roc_auc_score raises
                        # ValueError when a batch contains a single class.
                        logging.error(f"Error calculating metrics: {e}")

            # Validation pass. val_loss_list accumulates per-batch MEAN losses
            # and is divided by the number of validation samples (not batches);
            # that ratio is what early stopping tracks.
            val_loss_list = 0
            val_all_list = 0
            val_batch_all_preds = []
            val_batch_all_labels = []
            model.eval()
            with torch.no_grad():
                for step, (input_nodes, seeds, blocks) in enumerate(val_dataloader):
                    batch_inputs, batch_work_inputs, batch_labels, lpa_labels = load_lpa_subtensor(num_feat, cat_feat, labels,
                                                                                                   seeds, input_nodes, device)

                    blocks = [block.to(device) for block in blocks]
                    val_batch_logits = model(
                        blocks, batch_inputs, lpa_labels, batch_work_inputs)
                    # Out-of-fold logits are stored before label-2 nodes are dropped.
                    oof_predictions[seeds] = val_batch_logits
                    mask = batch_labels == 2
                    val_batch_logits = val_batch_logits[~mask]
                    batch_labels = batch_labels[~mask]
                    val_loss_list = val_loss_list + \
                        loss_fn(val_batch_logits, batch_labels)
                    val_batch_pred = torch.sum(torch.argmax(
                        val_batch_logits, dim=1) == batch_labels) / torch.tensor(batch_labels.shape[0])
                    val_all_list = val_all_list + batch_labels.shape[0]

                    # Collect predictions/labels for epoch-level metrics.
                    pred_labels = torch.argmax(val_batch_logits, dim=1).cpu().numpy()
                    batch_labels_np = batch_labels.cpu().numpy()
                    val_batch_all_preds.append(pred_labels)
                    val_batch_all_labels.append(batch_labels_np)

                    if step % 10 == 0:
                        score = torch.softmax(val_batch_logits, dim=1)[
                            :, 1].cpu().numpy()
                        val_acc1, val_acc0 = _class_recalls(batch_labels_np, pred_labels)
                        val_gmean = calculate_g_mean(batch_labels_np, pred_labels)

                        try:
                            log_msg = ('In epoch:{:03d}|batch:{:04d}, val_loss:{:4f}, val_ap:{:.4f}, '
                                       'val_acc:{:.4f}, val_auc:{:.4f}, val_acc1:{:.4f}, val_acc0:{:.4f}, val_gmean:{:.4f}')

                            logging.info(log_msg.format(epoch,
                                                        step,
                                                        val_loss_list/val_all_list,
                                                        average_precision_score(batch_labels_np, score),
                                                        val_batch_pred.detach(),
                                                        roc_auc_score(batch_labels_np, score),
                                                        val_acc1, val_acc0, val_gmean))
                        except Exception as e:
                            logging.error(f"Error calculating validation metrics: {e}")

                # Epoch-level validation metrics over all validation batches.
                if val_batch_all_labels:
                    all_val_labels = np.concatenate(val_batch_all_labels)
                    all_val_preds = np.concatenate(val_batch_all_preds)

                    val_acc1, val_acc0 = _class_recalls(all_val_labels, all_val_preds)
                    val_gmean = calculate_g_mean(all_val_labels, all_val_preds)

                    logging.info(f'Epoch {epoch} validation metrics - ACC1: {val_acc1:.4f}, ACC0: {val_acc0:.4f}, G-mean: {val_gmean:.4f}')

            earlystoper.earlystop(val_loss_list/val_all_list, model)
            if earlystoper.is_earlystop:
                logging.info("Early Stopping!")
                break
        logging.info("Best val_loss is: {:.7f}".format(earlystoper.best_cv))

        # Evaluate this fold's best model on the test nodes.
        test_ind = torch.from_numpy(np.array(test_idx)).long().to(device)
        test_sampler = MultiLayerFullNeighborSampler(args['n_layers'])
        test_dataloader = DataLoader(graph,
                                     test_ind,
                                     test_sampler,
                                     use_ddp=False,
                                     device=device,
                                     batch_size=args['batch_size'],
                                     shuffle=True,
                                     drop_last=False,
                                     num_workers=0,
                                     )
        b_model = earlystoper.best_model.to(device)
        b_model.eval()
        test_batch_all_preds = []
        test_batch_all_labels = []
        with torch.no_grad():
            for step, (input_nodes, seeds, blocks) in enumerate(test_dataloader):
                batch_inputs, batch_work_inputs, batch_labels, lpa_labels = load_lpa_subtensor(num_feat, cat_feat, labels,
                                                                                               seeds, input_nodes, device)

                blocks = [block.to(device) for block in blocks]
                test_batch_logits = b_model(
                    blocks, batch_inputs, lpa_labels, batch_work_inputs)
                # Overwritten by every fold: the last fold's values survive.
                test_predictions[seeds] = test_batch_logits

                pred_labels = torch.argmax(test_batch_logits, dim=1).cpu().numpy()
                test_batch_all_preds.append(pred_labels)
                test_batch_all_labels.append(batch_labels.cpu().numpy())

                if step % 10 == 0:
                    logging.info('In test batch:{:04d}'.format(step))

        # Per-fold test metrics over all test batches (label-2 nodes are
        # neither class 0 nor class 1, so they do not affect the recalls).
        if test_batch_all_labels:
            all_test_labels = np.concatenate(test_batch_all_labels)
            all_test_preds = np.concatenate(test_batch_all_preds)

            test_acc1, test_acc0 = _class_recalls(all_test_labels, all_test_preds)
            test_gmean = calculate_g_mean(all_test_labels, all_test_preds)

            logging.info(f'Overall test metrics - ACC1: {test_acc1:.4f}, ACC0: {test_acc0:.4f}, G-mean: {test_gmean:.4f}')

    # Out-of-fold AP over all of train_idx (label 2 counted as negative here).
    mask = y_target == 2
    y_target[mask] = 0
    my_ap = average_precision_score(y_target, torch.softmax(
        oof_predictions, dim=1).cpu()[train_idx, 1])
    logging.info("NN out of fold AP is: {:.4f}".format(my_ap))
    # Move the last fold's best model to the CPU.
    earlystoper.best_model.to('cpu')

    # Final test metrics from the stored test logits; label-2 nodes excluded.
    test_score = torch.softmax(test_predictions, dim=1)[test_idx, 1].cpu().numpy()
    y_test = labels[test_idx].cpu().numpy()
    test_score1 = torch.argmax(test_predictions, dim=1)[test_idx].cpu().numpy()

    keep = y_test != 2
    test_score = test_score[keep]
    y_test = y_test[keep]
    test_score1 = test_score1[keep]

    test_auc = roc_auc_score(y_test, test_score)
    test_f1 = f1_score(y_test, test_score1, average="macro")
    test_ap = average_precision_score(y_test, test_score)

    test_acc1, test_acc0 = _class_recalls(y_test, test_score1)
    test_gmean = calculate_g_mean(y_test, test_score1)

    logging.info("Final test AUC: {:.4f}".format(test_auc))
    logging.info("Final test F1: {:.4f}".format(test_f1))
    logging.info("Final test AP: {:.4f}".format(test_ap))
    logging.info("Final test ACC1: {:.4f}".format(test_acc1))
    logging.info("Final test ACC0: {:.4f}".format(test_acc0))
    logging.info("Final test G-mean: {:.4f}".format(test_gmean))


def _graph_from_adjlists(adj_lists, num_nodes):
    """Build a DGL graph with one edge ``i -> j`` for every ``j`` in ``adj_lists[i]``."""
    src = []
    tgt = []
    for i in adj_lists:
        for j in adj_lists[i]:
            src.append(i)  # edge source
            tgt.append(j)  # edge destination
    # num_nodes keeps nodes that have no adjacency entry in the graph.
    return dgl.graph((np.array(src), np.array(tgt)), num_nodes=num_nodes)


def _attach_and_save_graph(g, labels, feat_data, graph_path):
    """Store labels / features as node data on ``g`` and serialize it to ``graph_path``."""
    g.ndata['label'] = torch.from_numpy(labels.to_numpy()).to(torch.long)
    g.ndata['feat'] = torch.from_numpy(
        feat_data.to_numpy()).to(torch.float32)
    dgl.data.utils.save_graphs(graph_path, [g])


def load_gtan_data(dataset: str, test_size: float):
    """
    Load features, labels, a stratified train/test split and the graph for a dataset.

    Side effects: the built graph is saved to ``<data>/graph-<dataset>.bin``;
    for S-FFSD the preprocessed features/labels are also written to CSV.

    :param dataset: one of "S-FFSD", "yelp", "amazon"
    :param test_size: test fraction for ``train_test_split`` (halved for S-FFSD)
    :returns: (feat_data, labels, train_idx, test_idx, graph, cat_features)
    :raises ValueError: if ``dataset`` is not supported
    """
    if dataset not in ("S-FFSD", "yelp", "amazon"):
        raise ValueError(
            f"Unsupported dataset: {dataset!r}; expected 'S-FFSD', 'yelp' or 'amazon'")

    prefix = os.path.join(os.path.dirname(__file__), "..", "..", "data/")
    graph_path = prefix + "graph-{}.bin".format(dataset)

    if dataset == "S-FFSD":
        cat_features = ["Target", "Location", "Type"]

        df = pd.read_csv(prefix + "S-FFSDneofull.csv")
        df = df.loc[:, ~df.columns.str.contains('Unnamed')]
        data = df[df["Labels"] <= 2]
        data = data.reset_index(drop=True)

        # For every value of each column, sort its transactions by Time and
        # link each one to itself and to the next (edge_per_trans - 1) ones.
        pair = ["Source", "Target", "Location", "Type"]
        edge_per_trans = 3
        alls = []
        allt = []
        for column in pair:
            for _, c_df in data.groupby(column):
                sorted_idxs = c_df.sort_values(by="Time").index
                df_len = len(sorted_idxs)
                alls.extend(sorted_idxs[i] for i in range(df_len)
                            for j in range(edge_per_trans) if i + j < df_len)
                allt.extend(sorted_idxs[i + j] for i in range(df_len)
                            for j in range(edge_per_trans) if i + j < df_len)
        g = dgl.graph((np.array(alls), np.array(allt)), num_nodes=len(data))

        for col in pair:
            le = LabelEncoder()
            data[col] = le.fit_transform(data[col].apply(str).values)
        feat_data = data.drop("Labels", axis=1)
        labels = data["Labels"]
        # Cache the preprocessed features / labels next to the raw data.
        feat_data.to_csv(prefix + "S-FFSD_feat_data.csv", index=None)
        labels.to_csv(prefix + "S-FFSD_label_data.csv", index=None)
        index = list(range(len(labels)))
        _attach_and_save_graph(g, labels, feat_data, graph_path)

        train_idx, test_idx, _, _ = train_test_split(index, labels, stratify=labels, test_size=test_size/2,
                                                     random_state=72, shuffle=True)

    elif dataset == "yelp":
        cat_features = []
        data_file = loadmat(prefix + 'YelpChi.mat')
        labels = pd.DataFrame(data_file['label'].flatten())[0]
        feat_data = pd.DataFrame(data_file['features'].todense().A)
        # Preprocessed homogeneous adjacency lists.
        # NOTE: pickle.load can execute arbitrary code; only load trusted local files.
        with open(prefix + 'yelp_homo_adjlists.pickle', 'rb') as file:
            homo = pickle.load(file)
        index = list(range(len(labels)))
        train_idx, test_idx, _, _ = train_test_split(index, labels, stratify=labels, test_size=test_size,
                                                     random_state=72, shuffle=True)
        g = _graph_from_adjlists(homo, len(labels))
        _attach_and_save_graph(g, labels, feat_data, graph_path)

    elif dataset == "amazon":
        cat_features = []
        data_file = loadmat(prefix + 'Amazon.mat')
        labels = pd.DataFrame(data_file['label'].flatten())[0]
        feat_data = pd.DataFrame(data_file['features'].todense().A)
        # Preprocessed homogeneous adjacency lists.
        # NOTE: pickle.load can execute arbitrary code; only load trusted local files.
        with open(prefix + 'amz_homo_adjlists.pickle', 'rb') as file:
            homo = pickle.load(file)
        # Nodes 0..3304 stay in the graph but are excluded from the split
        # (presumably unlabeled users, as in CARE-GNN preprocessing -- confirm).
        index = list(range(3305, len(labels)))
        train_idx, test_idx, _, _ = train_test_split(index, labels[3305:], stratify=labels[3305:],
                                                     test_size=test_size, random_state=72, shuffle=True)
        g = _graph_from_adjlists(homo, len(labels))
        _attach_and_save_graph(g, labels, feat_data, graph_path)

    return feat_data, labels, train_idx, test_idx, g, cat_features
